"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.



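Compute the autocovariance integral while keeping the result sparse.

The file `<net_name>_tau_w<tau_w>_PTforw_<int_start>_to_<int_stop>.hkl` contains
the autocovariance integral between t_1 and t_2, where t_1 is the *start time*
of interval int_start and t_2 is the *start time* of interval int_stop. This is
also valid when int_start > int_stop (computation with --reverse_time).

The value saved (key 'ITPT') is:

    P(t_1) [ int_{t_1}^{t_2} T(t_1,t) P(t)^{-1} T(t_1,t)^T dt ] P(t_1)

where T(t_1,t) is the transition matrix from t_1 to t, P(t) = diag(p(t)) and
p(t) = p(t_1) T(t_1,t). The integral is computed as a sum over the time slices
of each interval. It is not normalized by its duration; the duration is saved
separately under the key 'integration_time'.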
    

"""
import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np
import time
from multiprocessing import Pool
from argparse import ArgumentParser
import hickle

import pandas as pd
from scipy.sparse import eye, csr_matrix

from TemporalNetwork import ContTempNetwork, set_to_zeroes
from SparseStochMat import (inplace_csr_row_normalize,
                            inplace_csr_matmul_diag, inplace_diag_matmul_csr,
                            sparse_gram_matrix, sparse_matmul)

import traceback
import re
import gc


# raise Exception

#%%

# ---------------------------------------------------------------------------
# Command line interface.
# The add_argument calls are kept in this order on purpose: it is the order
# in which the options are listed by --help.
# ---------------------------------------------------------------------------
ap = ArgumentParser()

# directory scanned for the per-interval transition matrix files (mandatory)
ap.add_argument('--datadir', default='', type=str,
                help="location of the inter trans. mats computed for each intervals")

# directory where the integral files are written (mandatory)
ap.add_argument('--savedir', default='', type=str,
                help="where the results will be saved.")

# p1_file and pk_init_dir are mutually exclusive: load_p_init() raises a
# ValueError unless exactly one of them is non-empty.
# NOTE(review): the help text below mentions a uniform default, but no uniform
# p1 is built anywhere in this file -- confirm.
ap.add_argument('--p1_file', default='', type=str,
                help=("location and filename of the initial probability distribution. "
                     "Default is a uniform distribution."))

ap.add_argument('--pk_init_dir', default='', type=str,
                help=("location of the initial probability distributions. For when "
                     "the init. dist. was computed during a previous run."))



# number of processes of the multiprocessing Pool
ap.add_argument('--ncpu', default=4, type=int)
# --nnode and --node_num are parsed but not read anywhere else in this file
ap.add_argument('--nnode', default=1, type=int)
ap.add_argument('--node_num', default=0, type=int)
# prefix of every file name read or written
ap.add_argument('--net_name', default='tempnet', type=str)

# --int_list, --int_length and --num_points select the large grid `all_inds`
# (--int_list has priority, then --int_length, then --num_points)
ap.add_argument('--num_points', default=50, type=int)

ap.add_argument('--int_length',default=None, type=int,
                help="Length of a single interval. Used to set the number of intervals instead of num_points")

ap.add_argument('--int_list', default=[], type=int, nargs='+',
                help="List intervals used for the integral. Used instead of num_points or int_length.")

# key into TM['T_lin'][1/tau_w] in load_trans_mat(); only used when lin=True
ap.add_argument('--t_s', default=10, type=float)

# passed to set_to_zeroes() for the transition matrices and their products
ap.add_argument('--tol', default=None, type=float,
                help='Values smaller than  max(val)*tol are set to zero in sparse transition matrices.')

# passed to set_to_zeroes() for the integral (worker)
ap.add_argument('--integral_rtol', default=None, type=float,
                help='Values smaller than max(integral)*rtol are set to zero in the sparse integral.')

ap.add_argument('--reverse_time', action='store_true',
                help="To compute the forward diffusion but with reversed time.")


ap.add_argument('--only_one_interval', action='store_true',
                help="instead of computing every combinations of start and finish, will compute from every start but only for one interval.")


ap.add_argument('--verbose', action='store_true',)

ap.add_argument('--print_mem_usage', action='store_true',
                help="print memory usage. Requires psutil module.")

# with --verbose, the worker prints progress (and memory usage if requested)
# every print_interval slices
ap.add_argument('--print_interval', default=100, type=int,
                help="Controls how often memory usage is printed.")


ap.add_argument('--save_intermediate', action='store_true',
                help="Saves, and loads, intermediate steps in order to be able to restart computation.")

ap.add_argument('--save_each_step', action='store_true',
                help="Saves, and loads, each steps (even more than save_intermediate) in order to be able to restart computation.")

# when given, replaces the tau_w values found in the file names of datadir
ap.add_argument('--tau_w_list', default=[], type=float, nargs='+',
                help="list of waiting times to analyze. given as '(tau1 tau2 ...)'")
                
# p_k files found here are loaded instead of being recomputed
ap.add_argument('--pk_dir', default='', type=str,
                help="Where to save the p_k vectors, computed for each inter event times.")
                
ap.add_argument('--save_pk_vecs', action='store_true',
                help="Whether to save the p_k vectors needed for the backward compuation.")
                


# Unpack the parsed arguments into module-level globals; the functions below
# read these globals directly.
inargs = vars(ap.parse_args())
datadir = inargs['datadir']
savedir = inargs['savedir']
p1_file = inargs['p1_file']
ncpu = inargs['ncpu']
nnode = inargs['nnode']
node_num = inargs['node_num']
net_name = inargs['net_name']
num_points = inargs['num_points']
int_length = inargs['int_length']
int_list = inargs['int_list']
t_s = inargs['t_s']
tol = inargs['tol']
integral_rtol = inargs['integral_rtol']

# not read anywhere else in this file
time_direction = 'forward'

reverse_time = inargs['reverse_time']

only_one_interval = inargs['only_one_interval']

verbose = inargs['verbose']

save_intermediate = inargs['save_intermediate']
save_each_step = inargs['save_each_step']

# gates the integral computation in worker(); always enabled here
use_expm_transmats = True


# psutil is only imported when memory printing is requested; if it is not
# installed, memory printing is disabled instead of failing.
print_mem_usage = inargs['print_mem_usage']
if print_mem_usage:
    try:
        import psutil
    except ImportError:
        print("Could not load psutil, will not print mem usage.")
        print_mem_usage = False

print_interval = inargs['print_interval']
tau_w_list = inargs['tau_w_list']
pk_dir = inargs['pk_dir']
pk_init_dir = inargs['pk_init_dir']
save_pk_vecs = inargs['save_pk_vecs']

#%%

# Both directories are mandatory: stop before touching the file system.
for _required in ('datadir', 'savedir'):
    if inargs[_required] == '':
        raise Exception(f'{_required} must be given')

# echo the complete configuration to the log
print('Arguments:')
for arg_pair in inargs.items():
    print(arg_pair)
#%%
all_files = sorted(os.listdir(datadir))

# Get the number of nodes from the first pickle of datadir that holds a dict
# with a 'num_nodes' entry.
num_nodes = None
for fname in all_files:
    compression = 'gzip' if fname.endswith('.gz') else None
    res = pd.read_pickle(os.path.join(datadir, fname),
                         compression=compression)
    if isinstance(res, dict):
        num_nodes = res.get('num_nodes')
    # drop the reference right away: these pickles can be large and a module
    # global would keep the last one alive for the whole run
    del res
    if num_nodes is not None:
        break
else:
    raise ValueError(f'no pickle file with a num_nodes entry found in {datadir}')

# Extract intervals and tau_ws from the file names, i.e. the tokens
# 'int000012' and 'w1.000e+00' of '<net_name>_tau_w1.000e+00_int000012_...'.
interval_token = re.compile(r'int([0-9]{6,})')
intervals = set()
tau_ws = set()
for fname in all_files:
    for extract in os.path.splitext(fname)[0].split('_'):
        int_match = interval_token.fullmatch(extract)
        if int_match:
            intervals.add(int(int_match.group(1)))
        elif extract.startswith('w'):
            try:
                tau_ws.add(float(extract[1:]))
            except ValueError:
                # a token that merely starts with 'w' (e.g. part of net_name)
                pass

intervals = sorted(intervals)
tau_ws = sorted(tau_ws)

if tau_w_list:
    tau_ws = tau_w_list

num_small_grid_points = len(intervals)

if int_list:
    all_inds = sorted(int_list)
else:
    if int_length is None:
        if num_points < 1:
            raise ValueError(f'num_points must be a positive integer, got {num_points}')
        # at least 1: range() below rejects a zero step, which the plain
        # floor division produced whenever num_points > num_small_grid_points
        int_length = max(1, num_small_grid_points // num_points)
    elif int_length < 1:
        raise ValueError(f'int_length must be a positive integer, got {int_length}')

    # we can go +1 over num_small_grid_points because the last interval is never loaded:
    # we compute up to the beginning of the last interval.
    all_inds = list(range(0, num_small_grid_points + 1, int_length))
    
#%% load p1


    
def load_p_init(int_start, tau_w, lin, reverse_time):
    """Return the probability vector at the start of interval `int_start`.

    Exactly one of the module-level settings ``p1_file`` and ``pk_init_dir``
    must be non-empty:

    * ``p1_file``: the vector is read with ``np.loadtxt`` and returned as a
      one-row ``csr_matrix``; the other arguments are not used.
    * ``pk_init_dir``: the vector is the ``'pk'`` entry of a file saved by a
      previous run (read with ``hickle``). With `reverse_time` it is slice 0
      of interval `int_start`; otherwise it is the last slice of interval
      ``int_start - 1``, i.e. the last p_k of the previous interval.

    Parameters
    ----------
    int_start : int
        Index of the interval the computation starts from.
    tau_w : float
        Waiting time; part of the p_k file names.
    lin : bool
        If True, look for p_k files carrying the additional ``lin_`` tag.
    reverse_time : bool
        Selects which saved slice is used (see above).

    Returns
    -------
    The probability vector, checked to sum to 1.

    Raises
    ------
    ValueError
        If neither or both of ``p1_file`` and ``pk_init_dir`` are set, or if
        the loaded vector does not sum to 1.
    FileNotFoundError
        In forward time, if no p_k file of interval ``int_start - 1`` exists.
    """
    if p1_file != '' and pk_init_dir == '':
        print('PID ', os.getpid(), f'getting p1 from {p1_file}')

        p1 = np.loadtxt(p1_file, dtype=np.float64)
        # explicit check rather than assert, which is stripped under python -O
        if not np.allclose(p1.sum(), 1.0):
            raise ValueError(f'p1 read from {p1_file} sums to {p1.sum()}, not 1.')
        return csr_matrix(p1)

    if p1_file == '' and pk_init_dir != '':
        print('PID ', os.getpid(), f'getting p1 from {pk_init_dir} at interval {int_start} and tau_w {tau_w}')

        pk_filenames = os.path.join(pk_init_dir, net_name +
                                    f'_tau_w{tau_w:.3e}' + '_pk' +
                                    '_int{interval:06d}_')
        if lin:
            pk_filenames += 'lin_'

        if reverse_time:
            # pk at the beginning of int_start is pk_int_start_slice0
            pk_file = pk_filenames.format(interval=int_start) + 'slice000000.hkl'
        else:
            # for forward time, pk at the beginning of int_start is
            # pk_int(int_start-1)_slice(last slice). Slice numbers are
            # zero-padded, so the last file in name order is the last slice.
            prev_prefix = pk_filenames.format(interval=int_start - 1)
            search_dir = os.path.dirname(prev_prefix)
            slice_prefix = os.path.basename(prev_prefix) + 'slice'
            candidates = sorted(f for f in os.listdir(search_dir)
                                if f.startswith(slice_prefix) and f.endswith('.hkl'))
            if not candidates:
                raise FileNotFoundError(
                    f'no p_k file matching {prev_prefix}slice*.hkl found')
            pk_file = os.path.join(search_dir, candidates[-1])

        p_init = hickle.load(pk_file)
        pk = p_init['pk']
        if not np.allclose(pk.sum(), 1.0):
            raise ValueError(f'pk read from {pk_file} sums to {pk.sum()}, not 1.')
        return pk

    raise ValueError('Exactly one of p1_file or pk_init_dir must be provided.')
        
#%% load trans mat
    
def load_trans_mat(k_range, tau_w, lin, reverse_time, tol, verbose=False):
    """Chain the saved per-interval transition matrices over `k_range`.

    Used to rebuild the transition matrix quickly when an integral
    computation is restarted.

    For every k of `k_range`, in the given order, the file
    ``<datadir>/<net_name>_tau_w<tau_w>_int<k>__[reversed_][lin_]trans_mat``
    is read with ``ContTempNetwork.load_T``. The matrix
    ``TM['T_lin'][1/tau_w][t_s]`` (if `lin`) or ``TM['T'][1/tau_w]`` is
    sparsified with `tol`, row-normalized and right-multiplied into the
    running product. The product is itself sparsified and row-normalized
    after every step.

    Returns
    -------
    T : the product, starting from the num_nodes x num_nodes csr identity
        (returned as the identity for an empty `k_range`).
    integr_time : sum over k of abs(_t_stop_laplacians - _t_start_laplacians).
    """
    variant = ('reversed_' if reverse_time else '') + \
              ('lin_trans_mat' if lin else 'trans_mat')
    T_file = os.path.join(datadir, net_name + f'_tau_w{tau_w:.3e}'
                          + '_int{k:06d}__' + variant)

    rate_key = 1 / tau_w
    T = eye(num_nodes, format='csr', dtype=np.float64)
    integr_time = 0

    for k in k_range:
        TM = ContTempNetwork.load_T(T_file.format(k=k))
        Tk = TM['T_lin'][rate_key][t_s] if lin else TM['T'][rate_key]

        # clean up the single-interval matrix ...
        set_to_zeroes(Tk, tol=tol)
        inplace_csr_row_normalize(Tk)

        # ... chain it, then clean up the running product the same way
        T = sparse_matmul(T, Tk.tocsr(), verbose=verbose, log_message='T')
        set_to_zeroes(T, tol=tol)
        inplace_csr_row_normalize(T)

        integr_time += abs(TM['_t_stop_laplacians'] - TM['_t_start_laplacians'])

    return T, integr_time

#%% worker

def worker(ind_start_tau_w):
    """Compute and save the integrals ITPT for one start point and one tau_w.

    `ind_start_tau_w` is the tuple
    ``(ind_start, tau_w, reverse_time, tol, integral_rtol)``, where
    ``ind_start`` indexes ``all_inds``. For every stop point following
    ``int_start = all_inds[ind_start]`` in the chosen time direction (only the
    first one with --only_one_interval), the accumulated integral is saved to
    ``ITPT_file.format(int_start, int_stop)``. The saved keys are 'ITPT',
    'p1', 'pk', 'integration_time', 'integral_rtol' and 'tol'.

    Each slice of duration dtk contributes ``dtk * P1 Tk Pk^-1 Tk^T P1``:
    * Tk is the product of the (sparsified, row-normalized) inter-slice
      transition matrices since int_start;
    * pk = p1 @ Tk, with zero entries replaced by 1;
    * P = diag(p).
    This assumes ``sparse_gram_matrix(A, transpose=True)`` returns A @ A^T.

    Existing final, intermediate (--save_intermediate) and per-slice
    (--save_each_step) files are reused, so an interrupted run can resume.

    Exceptions raised inside the main try block are reported on stdout and
    stderr (with a traceback) and are not re-raised. Returns None.
    """
    ind_start, tau_w, reverse_time, tol, integral_rtol = ind_start_tau_w

    # all_inds is a list of interval indices forming the large grid points;
    # ind_start is the position in all_inds from where to start
    int_start = all_inds[ind_start]

    print('PID ', os.getpid(), f'interval start {int_start} for tau_w {tau_w}')

    # per-interval inter-slice transition matrices (input)
    inter_trans_file0 = os.path.join(datadir, net_name + f'_tau_w{tau_w:.3e}' + '_int{k:06d}_')

    save_prefix = os.path.join(savedir, net_name + f'_tau_w{tau_w:.3e}')
    # final integral from int_start to a stop point (forward diffusion)
    ITPT_file = save_prefix + '_PTforw_{0:06d}_to_{1:06d}.hkl'
    # state saved after each interval with --save_intermediate
    ITPT_temp_file = save_prefix + '_PTforw_temp_integration_{0:06d}_to_{1:06d}.hkl'
    # state saved after each slice with --save_each_step
    ITPT_temp_step_file = save_prefix + '_PTforw_temp_step_integration_{0:06d}_to_{1:06d}_slice{2:06d}.hkl'
    # p_k vectors, per interval and slice
    pk_file = os.path.join(pk_dir, net_name + f'_tau_w{tau_w:.3e}') + '_pk_int{0:06d}_slice{1:06d}.hkl'

    # I is the sum of two integrals I = (1/T) * P1 (int TPT) P1 - p1p1
    # where TPT = int_t1^t2 T(t1,t) P(t)^-1 T(t1,t)^T dt.
    # Only the P1 (int TPT) P1 part is computed here, called ITPT.
    ITPT = csr_matrix((num_nodes, num_nodes), dtype=np.float64)
    # transition matrix from the start of int_start to the current slice
    Tk = eye(num_nodes, format='csr', dtype=np.float64)

    p1 = load_p_init(int_start, tau_w, False, reverse_time)

    # p1 does not change during the computation: densify it only once
    p1_arr = p1.toarray()
    # multiplying by P1 is not enough to keep ITPTk sparse when p1 has many
    # non-zero entries; small values are then set to zero as well
    sparsify_integral = (p1_arr > 0).sum() > 5000

    _int_start = int_start

    if reverse_time:
        # stop points before int_start, nearest first (none if ind_start == 0)
        int_stops = all_inds[:ind_start][::-1]
        d = -1
    else:
        int_stops = all_inds[ind_start+1:]
        d = 1
    if only_one_interval:
        int_stops = int_stops[:1]

    # bound before the try block so that the error report can always use it
    int_stop = None
    # p vector matching the current Tk; None when Tk was loaded from a file
    pk = None
    try:

        integration_time = 0

        # check if all files have already been computed; if not, load the
        # already existing files and update Tk and integration time accordingly
        all_expm_exists = False
        last_expm_exist = -1

        if use_expm_transmats:

            # check which results already exist
            integrals_expm_exists = np.array([os.path.isfile(ITPT_file.format(int_start, stop)) or
                                              os.path.isfile(ITPT_file.format(int_start, stop)+'.gz')
                                              for stop in int_stops])

            done_prefix = integrals_expm_exists.cumprod().nonzero()[0]
            if done_prefix.size > 0:
                # last index of the leading run of already computed integrals
                last_expm_exist = done_prefix.max()

                # if they all exist there is nothing to compute
                all_expm_exists = last_expm_exist + 1 == integrals_expm_exists.size

                if not all_expm_exists:
                    # resume from the last computed integral
                    int_stop = int_stops[last_expm_exist]

                    # check if Tk was saved in a temp file
                    if os.path.isfile(ITPT_temp_file.format(int_start, int_stop)):
                        print('PID ', os.getpid(), ' loading I and Tk from ', ITPT_temp_file.format(int_start, int_stop))

                        ITPT_load = hickle.load(ITPT_temp_file.format(int_start, int_stop))
                        Tk = ITPT_load['Tk']
                        integration_time = ITPT_load['integration_time']

                    elif os.path.isfile(ITPT_temp_step_file.format(int_start, int_stop, 0)):
                        print('PID ', os.getpid(), ' loading I and Tk from ', ITPT_temp_step_file.format(int_start, int_stop, 0))

                        # in this case, ITPT and Tk correspond in fact to the state after slice 0, but they
                        # will be correctly used because they will be reloaded below. Not optimal, but correct
                        # and avoid recomputing Tk.
                        # NOTE(review): in reversed time the slices are processed from the last one down
                        # to slice 0, and 'integration_time_up_to_last_int' does not include interval
                        # int_stop itself -- confirm this resume path for --reverse_time.
                        ITPT_load = hickle.load(ITPT_temp_step_file.format(int_start, int_stop, 0))
                        Tk = ITPT_load['Tk']
                        integration_time = ITPT_load['integration_time_up_to_last_int']

                    else:
                        if os.path.isfile(ITPT_file.format(int_start, int_stop)):
                            ITPT_load = hickle.load(ITPT_file.format(int_start, int_stop))
                            print('PID ', os.getpid(), ' loading I from ', ITPT_file.format(int_start, int_stop))
                        else:
                            ITPT_load = hickle.load(ITPT_file.format(int_start, int_stop)+'.gz')
                            print('PID ', os.getpid(), ' loading I from ', ITPT_file.format(int_start, int_stop)+'.gz')

                        if reverse_time:
                            load_range = range(int_start-1, int_stop-1, d)
                        else:
                            load_range = range(int_start, int_stop, d)

                        print('PID ', os.getpid(), ' computing Tk for intervals', load_range)

                        Tk, integration_time = load_trans_mat(load_range,
                                                              tau_w, lin=False, reverse_time=reverse_time,
                                                              tol=tol,
                                                              verbose=verbose)

                        assert integration_time == ITPT_load['integration_time']

                    ITPT = ITPT_load['ITPT']

                    # update initial condition
                    _int_start = int_stop
                    int_stops = int_stops[last_expm_exist+1:]

                    del ITPT_load

        compute_expm = use_expm_transmats and not all_expm_exists

        if all_expm_exists:
            # no need to compute the expm integral
            print('PID ', os.getpid(),
                  f' expm trans integral from int {int_start} for tau_w {tau_w} already computed')

        if compute_expm:
            for int_stop in int_stops:

                t0 = time.time()

                print('PID ', os.getpid(),
                      f' computing trans from int {int_start} to int {int_stop} for tau_w {tau_w}')

                if reverse_time:
                    k_range = range(_int_start-1, int_stop-1, d)
                else:
                    k_range = range(_int_start, int_stop, d)

                for k in k_range:

                    if verbose:
                        print('PID ', os.getpid(),
                              f' -- k = {k} over {int_stop}')

                    # The intermediate file written after interval k is named
                    # "to k+1" in forward time and "to k" in reversed time; the
                    # same name is used for saving and for loading.
                    temp_file_k = ITPT_temp_file.format(int_start, k + (not reverse_time))

                    if save_intermediate and os.path.isfile(temp_file_k):
                        # load this step that has already been computed
                        print('PID ', os.getpid(), ' loading temp step ', temp_file_k)
                        ITPT_temp_load = hickle.load(temp_file_k)

                        assert ITPT_temp_load['last_treated_interval'] == k

                        ITPT = ITPT_temp_load['ITPT']
                        Tk = ITPT_temp_load['Tk']
                        integration_time = ITPT_temp_load['integration_time']
                        pk = None

                        del ITPT_temp_load
                        continue

                    inter_Ts = ContTempNetwork.load_inter_T(inter_trans_file0.format(k=k) +
                                                            '_inter_trans_mat')

                    tl = time.time()
                    # number of slices
                    num_l = len(inter_Ts["inter_T"][1/tau_w])

                    int_time_k = 0
                    # integrate over the slices of interval k (last slice first
                    # in reversed time)
                    for l, inter_Tk, dtk in zip(list(range(num_l))[::d],
                                                inter_Ts['inter_T'][1/tau_w][::d],
                                                d * np.diff(inter_Ts['times_k_start_to_k_stop+1'][::d])):

                        step_file_kl = ITPT_temp_step_file.format(int_start, k, l)

                        if save_each_step and os.path.isfile(step_file_kl):
                            # load this step that has already been computed
                            print('PID ', os.getpid(), ' loading temp step ', step_file_kl)
                            ITPT_temp_step_load = hickle.load(step_file_kl)

                            assert ITPT_temp_step_load['last_treated_interval'] == k

                            ITPT = ITPT_temp_step_load['ITPT']
                            Tk = ITPT_temp_step_load['Tk']
                            pk = None

                            del ITPT_temp_step_load
                            continue

                        # compute this step
                        if verbose:
                            print('PID ', os.getpid(), f' computing interval {k}, slice {l}')

                        set_to_zeroes(inter_Tk, tol=tol)
                        inplace_csr_row_normalize(inter_Tk)

                        Tk = sparse_matmul(Tk, inter_Tk.tocsr(),
                                           verbose=verbose,
                                           log_message='Tk')
                        set_to_zeroes(Tk, tol=tol)
                        inplace_csr_row_normalize(Tk)

                        if os.path.isfile(pk_file.format(k, l)):
                            if verbose:
                                print('PID ', os.getpid(), 'loading', pk_file.format(k, l))

                            phickle = hickle.load(pk_file.format(k, l))
                            pk = phickle['pk']
                            del phickle
                        else:
                            pk = sparse_matmul(p1, Tk.tocsr(),
                                               verbose=verbose,
                                               log_message='pk')

                            if save_pk_vecs:
                                print('PID ', os.getpid(), 'saving pk to', pk_file.format(k, l))

                                hickle.dump({'pk': pk, 'dt': dtk,
                                             'interval': k,
                                             'slice': l,
                                             'T_tol': tol}, pk_file.format(k, l))

                        # in order to avoid nan in ITPTk due to 0 * np.inf
                        pk_arr = pk.toarray()
                        pk_arr[pk_arr == 0] = 1

                        # ITPTk = (Tk @ Pk^-1/2) @ (Tk @ Pk^-1/2)^T
                        ITPTk = Tk.copy().tocsr()
                        inplace_csr_matmul_diag(ITPTk, np.sqrt(1/pk_arr))

                        while True:
                            try:
                                ITPTk = sparse_gram_matrix(ITPTk, transpose=True,
                                                           verbose=verbose,
                                                           log_message='ITPTk')
                            except ValueError as e:
                                alloc_failed = (bool(e.args) and isinstance(e.args[0], str)
                                                and e.args[0].endswith('(SPARSE_STATUS_ALLOC_FAILED)'))
                                if not alloc_failed or integral_rtol is None:
                                    # nothing to relax: report the real error
                                    raise
                                integral_rtol *= 1.1
                                if tol is not None:
                                    tol *= 1.1
                                print('PID ', os.getpid(), f' sparse_gram_matrix ITPTk ALLOC FAILED, nnz={ITPTk.nnz}, increasing integral_rtol to {integral_rtol} and sparsifying ITPTk')
                                # retrying on the unchanged matrix would fail again:
                                # drop its small entries with the relaxed threshold
                                set_to_zeroes(ITPTk, integral_rtol)
                                ITPTk.eliminate_zeros()
                                continue
                            break

                        ITPTk.data *= dtk  # operating on data avoids making a copy here.

                        # left and right multiply by P1, will keep sparsity high
                        inplace_diag_matmul_csr(ITPTk, p1_arr)
                        inplace_csr_matmul_diag(ITPTk, p1_arr)
                        ITPTk.eliminate_zeros()

                        if sparsify_integral:
                            set_to_zeroes(ITPTk, integral_rtol)

                        ITPT = ITPT + ITPTk

                        if sparsify_integral:
                            set_to_zeroes(ITPT, integral_rtol)

                        int_time_k += dtk

                        if verbose and not l % print_interval:
                            print('PID ', os.getpid(), f' -- k = {k} over {k_range[-1]}, integrating {l} over {num_l},',
                                  f'took {time.time()-tl:0.3f},',
                                  f'\n, ITPT nnz = {ITPT.nnz},',
                                  f', ITPT size (GB) = {(ITPT.data.nbytes + ITPT.indptr.nbytes + ITPT.indices.nbytes)/1024**3:0.6f}.',
                                  f'\n, pk nnz = {pk.nnz},',
                                  f'\n, Tk nnz = {Tk.nnz},',
                                  f', Tk size (GB) = {(Tk.data.nbytes + Tk.indptr.nbytes + Tk.indices.nbytes)/1024**3:0.6f}.',
                                  )
                            if print_mem_usage:
                                minf = psutil.virtual_memory()
                                print('\nPID ', os.getpid(), f'Memory info (GB): used {minf.used/1024**3:0.3f} ({minf.percent}%), available {minf.available/1024**3:0.3f}, active {minf.active/1024**3:0.3f}, inactive {minf.inactive/1024**3:0.3f}, buffers {minf.buffers/1024**3:0.3f}')
                            tl = time.time()

                            print('PID ', os.getpid(),
                                  f', ITPT sum slice {(ITPT.data.sum()*2 - ITPT.diagonal().sum())*(1/(integration_time+int_time_k))}')

                        del ITPTk
                        gc.collect()

                        if save_each_step:
                            print('PID ', os.getpid(), ' saving intermediate results to ',
                                  step_file_kl)

                            hickle.dump({'ITPT': ITPT,
                                         'interval_start': int_start,
                                         'last_treated_interval': k,
                                         'last_treated_slice': l,
                                         'Tk': Tk,
                                         'integration_time_up_to_last_int': integration_time,
                                         'tol': tol,
                                         'integral_rtol': integral_rtol},
                                        step_file_kl)

                    integration_time += abs(inter_Ts['_t_stop_laplacians']-inter_Ts['_t_start_laplacians'])

                    # should be close to 1 (undefined while no time has elapsed)
                    if integration_time > 0:
                        ITPT_sum = (ITPT.data.sum()*2 - ITPT.diagonal().sum())*(1/integration_time)
                    else:
                        ITPT_sum = np.nan
                    if verbose:
                        print('PID ', os.getpid(),
                              f', ITPT sum {ITPT_sum}')

                    del inter_Ts
                    gc.collect()

                    if save_intermediate and integration_time > 0:
                        print('PID ', os.getpid(), ' saving intermediate results to ',
                              temp_file_k)

                        hickle.dump({'ITPT': ITPT,
                                     'interval_start': int_start,
                                     'last_treated_interval': k,
                                     'integration_time': integration_time,
                                     'Tk': Tk,
                                     'tol': tol,
                                     'integral_rtol': integral_rtol,
                                     'ITPT_sum': ITPT_sum},
                                    temp_file_k)

                # saving results
                if integration_time > 0:
                    if pk is None:
                        # the current Tk was reloaded from a file in this run,
                        # so its pk has not been computed yet
                        pk = sparse_matmul(p1, Tk.tocsr(),
                                           verbose=verbose,
                                           log_message='pk')
                    print('PID ', os.getpid(), ' saving to ', ITPT_file.format(int_start, int_stop))

                    hickle.dump({'ITPT': ITPT,
                                 'p1': p1,
                                 'pk': pk,
                                 'integration_time': integration_time,
                                 'integral_rtol': integral_rtol,
                                 'tol': tol},
                                ITPT_file.format(int_start, int_stop), 'w')
                else:
                    print('PID ', os.getpid(), f' integration time is zero, not saving {int_start} to {int_stop}')

                t1 = time.time()
                print('PID ', os.getpid(), 'finished in {:.2f}'.format(t1 - t0))

                _int_start = int_stop

        del ITPT
        del Tk
        gc.collect()

    except Exception:
        print('PID ', os.getpid(), '-+-+-+ Exception at int_start=', int_start,
              ' int_stop=', int_stop, ' tau_w=', tau_w, 'ind_start=', ind_start,
              file=sys.stdout)
        print('PID ', os.getpid(), '-+-+-+ Exception at int_start=', int_start,
              ' int_stop=', int_stop, ' tau_w=', tau_w, 'ind_start=', ind_start,
              file=sys.stderr)

        traceback.print_exc(file=sys.stderr)
      
            
            
        
#%%

# combination of ind_start and all_inds
if __name__ == '__main__':
    t00 = time.time()

    # One task per waiting time. Forward time starts at the first point of
    # the large grid, reversed time at the last one.
    first_ind = len(all_inds) - 1 if reverse_time else 0
    ind_starts_tau_ws = [(first_ind, tau_w, reverse_time, tol, integral_rtol)
                         for tau_w in tau_ws]

    print(ind_starts_tau_ws)

    print('starting pool of {0} cpus'.format(ncpu))
    with Pool(ncpu) as p:
        # worker() returns None; errors inside its main loop are reported
        # by the worker itself
        data = p.map(worker, ind_starts_tau_ws)

    print('***** Finished! in {:.2f}'.format(time.time()-t00))
